import json

import tensorflow as tf
import yaml

from data.preprocess_scripts import *
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
from data.utils import capitalize_and_period

# The dataset without state
DATASET_NAMES_NO_STATE = [
    "nyu_door_opening_surprising_effectiveness",
    "usc_cloth_sim_converted_externally_to_rlds",
    "cmu_franka_exploration_dataset_converted_externally_to_rlds",
    "imperialcollege_sawyer_wrist_cam",
]

# Read the image keys of each dataset
with open("configs/dataset_img_keys.json", "r") as file:
    IMAGE_KEYS = json.load(file)
# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)


def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str, base_concat=None, base_format=None) -> tf.Tensor:
    """
    Assemble the state/action vector from the arm and base.
    """
    state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)

    # Assemble the arm state
    arm_concat = tf.cast(arm_concat, tf.float32)
    arm_format = arm_format.split(",")
    # Use the scatter_nd to avoid the duplicate indices
    state_vec = tf.tensor_scatter_nd_update(state_vec, [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
                                            arm_concat)
    mask_vec = tf.tensor_scatter_nd_update(
        mask_vec,
        [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
        tf.ones(len(arm_format), dtype=tf.float32),
    )

    # Assemble the base state if exists
    if base_concat is not None:
        base_concat = tf.cast(base_concat, tf.float32)
        base_format = base_format.split(",")
        state_vec = tf.tensor_scatter_nd_update(
            state_vec,
            [[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
            base_concat,
        )
        mask_vec = tf.tensor_scatter_nd_update(
            mask_vec,
            [[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
            tf.ones(len(base_format), dtype=tf.float32),
        )
    return state_vec, mask_vec


@tf.autograph.experimental.do_not_convert
def _generate_json_state_agilex(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    # Check whether this episode has an 'END'
    base_act = None
    last_base_act = None
    episode_states = []
    episode_acts = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]

        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]

        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
        last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)

        act_vec, mask_vec = assemble_state_vec(action["arm_concat"], arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)
        episode_acts.append(act_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    episode_metadata["#steps"] = step_id

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    episode_acts = tf.stack(episode_acts)

    return episode_metadata, episode_states, episode_masks, episode_acts


@tf.autograph.experimental.do_not_convert
def _generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for a given episode.
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    # Check whether this episode has an 'END'
    base_act = None
    last_base_act = None
    episode_states = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]

        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            act_format = action["format"].numpy().decode("utf-8")
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]

        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            if last_base_act is None:
                base_state = base_act * 0
            else:
                base_state = last_base_act
        last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks


@tf.autograph.experimental.do_not_convert
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode in the dataset without state.
    If not state, we use the last action as current state.
    """
    # Load some constants from the config
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    last_base_act = None
    last_arm_act = None
    episode_states = []
    episode_masks = []
    has_base = None
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]
            if last_base_act is None:
                last_base_act = base_act * 0  # Initialize

        # Parse the arm action
        arm_act = action["arm_concat"]
        if last_arm_act is None:
            last_arm_act = arm_act * 0  # Initialize

        # Parse the act format
        # Action format as the state format
        act_format = action["format"].numpy().decode("utf-8")

        # Assemble the state vector
        if has_base:
            last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
        else:
            last_act_concat = last_arm_act
        state_vec, mask_vec = assemble_state_vec(last_act_concat, act_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Write to the episode_metadata
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

        # Update the last_arm_act and last_base_act
        last_arm_act = arm_act
        if has_base:
            last_base_act = base_act

    episode_metadata["#steps"] = step_id
    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)

    return episode_metadata, episode_states, episode_masks


@tf.autograph.experimental.do_not_convert
def generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state for an episode.
    """
    if isinstance(dataset_name, tf.Tensor):
        dataset_name = dataset_name.numpy().decode("utf-8")

    # Process each step in the episode
    episode["steps"] = episode["steps"].map(globals()[dataset_name].process_step, )

    if dataset_name == "agilex":
        return _generate_json_state_agilex(episode, dataset_name)

    if dataset_name in DATASET_NAMES_NO_STATE:
        return _generate_json_state_nostate_ds(episode, dataset_name)

    return _generate_json_state(episode, dataset_name)
